"""
    White-balance algorithms for RGB images: fixed channel gains, perfect reflector, and gray world.

    author: wxz
    date:
    github: https://github.com/xinzwang
"""

import numpy as np


def channel_gain_white_balance(img, channel_gain=(1.9296875, 1.0, 2.26171875)):
    """Scale channels 0, 1 and 2 by fixed gains and clip the result to [0, 1].

    Args:
        img: Image indexed as ``img[row, col, channel]``; only channels 0-2
            are scaled.
        channel_gain: Multipliers for channels 0, 1 and 2 (entries past index
            2 are ignored).

    Returns:
        A new float64 array with ``img``'s shape. Channels past index 2 are
        left at zero.
    """
    balanced = np.zeros(img.shape)
    for c in range(3):
        balanced[:, :, c] = img[:, :, c] * channel_gain[c]
    return np.clip(balanced, 0., 1)


def perfect_reflector_white_balance(img, ratio=0.10):
    """Perfect-reflector white balance.

    The brightest ``ratio`` fraction of pixels, ranked by the sum of channels
    0-2, is used as the white reference. Each channel is then scaled so that
    its mean over the reference pixels maps to the largest single channel
    value in the image.

    Args:
        img: Image indexed as ``img[row, col, channel]``; channels 0-2 are
            processed. Values are assumed to lie in [0, 1]: the output is
            clipped to that range and the brightness histogram spans [0, 3].
        ratio: Fraction of pixels used as the white reference.

    Returns:
        Array with ``img``'s shape and dtype. Channels 0-2 hold the balanced
        values clipped to [0, 1]; any further channels are zero.
    """
    out = np.zeros_like(img)
    if img.size == 0:
        return out

    rgb = img[:, :, :3]
    height, width = rgb.shape[:2]

    # Per-pixel brightness. With channels in [0, 1] the sum lies in [0, 3], so
    # the histogram must span that whole range; a [0, 1] range would silently
    # drop every pixel brighter than 1/3 of full scale.
    brightness = rgb.sum(axis=2)
    n_bins = 766  # same resolution as the 8-bit formulation (sums 0..765)
    counts, edges = np.histogram(brightness, bins=n_bins, range=(0.0, 3.0))

    # Accumulate bins from the brightest down. The first bin at which more
    # than `ratio` of all pixels have been collected gives the threshold.
    num_ratio = height * width * ratio
    cum_from_top = np.cumsum(counts[::-1])
    crossed = np.flatnonzero(cum_from_top > num_ratio)
    threshold = edges[n_bins - 1 - crossed[0]] if crossed.size else 0.0

    # Mean colour of the reference pixels, averaged over the pixels actually
    # selected (not over a histogram bin count).
    reference = rgb[brightness >= threshold]
    avg = reference.mean(axis=0) if len(reference) else np.zeros(3)

    # Target level is the brightest single channel value, not a channel sum.
    max_val = rgb.max()
    # A zero channel mean cannot be rescaled (it would give inf and then
    # 0 * inf = NaN), so such channels keep unit gain.
    gains = np.divide(max_val, avg, out=np.ones(3), where=avg > 0)

    out[:, :, :3] = rgb * gains
    return np.clip(out, 0, 1)


def gray_world_white_balance(img):
    """Gray-world white balance.

    Each of channels 0-2 is scaled so that its mean equals the average of the
    three channel means.

    Args:
        img: Image indexed as ``img[row, col, channel]``. Values are assumed
            to lie in [0, 1], since the result is clipped to that range.

    Returns:
        Array with ``img``'s shape and dtype. Channels 0-2 hold the balanced
        values clipped to [0, 1]; any further channels are zero.
    """
    rgb = img[:, :, :3]
    channel_means = np.array([np.average(rgb[:, :, c]) for c in range(3)])
    gray = (channel_means[0] + channel_means[1] + channel_means[2]) / 3

    # A channel with zero mean cannot be rescaled: K / 0 = inf, and then
    # 0 * inf = NaN, which np.clip does not remove. Keep unit gain there.
    gains = np.divide(gray, channel_means, out=np.ones(3),
                      where=channel_means > 0)

    out = np.zeros_like(img)
    out[:, :, :3] = rgb * gains
    return np.clip(out, 0, 1)
